Link to the paper: [link]
1. Discrete-Time: Markov Chains and Time Reversal (Recap)

Consider a data distribution with positive density $p_{\text{data}}$, a positive prior density $p_{\text{prior}}$, and a Markov chain with initial density $p_0 = p_{\text{data}}$ on $\mathbb{R}^d$ evolving according to positive transition densities $p_{k+1|k}$ for $k \in \{0, \dots, N-1\}$. By the Markov property, for any $x_{0:N} = \{x_k\}_{k=0}^N \in \mathcal{X} = (\mathbb{R}^d)^{N+1}$, the joint density can be expressed as:
\begin{equation}
p(x_{0:N}) = p_0(x_0) \prod_{k=0}^{N-1} p_{k+1|k}(x_{k+1} | x_k)
\end{equation}

The joint density also admits the backward decomposition:
\begin{equation}
p(x_{0:N}) = p_N(x_N) \prod_{k=0}^{N-1} p_{k|k+1}(x_k | x_{k+1})
\end{equation}

where $p_k(x_k) = \int p_{k|k-1}(x_k | x_{k-1}) p_{k-1}(x_{k-1}) \, dx_{k-1}$ is the marginal density at step $k \geq 1$.
☝
Why is the backward decomposition correct?

Proof by induction:

When $N = 1$, Equation (2) is just Bayes' rule: $p(x_0, x_1) = p_0(x_0) p_{1|0}(x_1 | x_0) = p_1(x_1) p_{0|1}(x_0 | x_1)$.

Assume the equation holds for $N = M$, $M \in \mathbb{N}_+$; we then show it also holds for $N = M+1$.
\begin{align*}
p(x_{0:M+1}) &= p(x_{0:M}) \, p(x_{M+1} | x_{0:M}) \\
&= \left[ p_M(x_M) \prod_{k=0}^{M-1} p_{k|k+1}(x_k | x_{k+1}) \right] p(x_{M+1} | x_{0:M}) \quad \text{(by assumption)} \\
&= \left[ p_M(x_M) \prod_{k=0}^{M-1} p_{k|k+1}(x_k | x_{k+1}) \right] p_{M+1|M}(x_{M+1} | x_M) \quad \text{(by the Markov property)} \\
&= p_{M+1}(x_{M+1}) \, p_{M|M+1}(x_M | x_{M+1}) \prod_{k=0}^{M-1} p_{k|k+1}(x_k | x_{k+1}) \quad \text{(by Bayes' rule)} \\
&= p_{M+1}(x_{M+1}) \prod_{k=0}^{M} p_{k|k+1}(x_k | x_{k+1})
\end{align*}

End of Proof
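As a quick numerical sanity check of Equations (1) and (2), here is a toy example; the two-state chain and its transition matrices below are arbitrary illustrative choices:

```python
import numpy as np

# Toy 2-state chain with N = 2 (arbitrary illustrative numbers).
p0 = np.array([0.3, 0.7])                     # initial marginal p_0
P = [np.array([[0.9, 0.1], [0.2, 0.8]]),      # p_{1|0}(x_1 | x_0)
     np.array([[0.6, 0.4], [0.5, 0.5]])]      # p_{2|1}(x_2 | x_1)

# Marginals p_1, p_2 by propagating p_0 through the forward kernels.
p1 = p0 @ P[0]
p2 = p1 @ P[1]

# Backward kernels p_{k|k+1}(x_k | x_{k+1}) from Bayes' rule.
B0 = P[0] * p0[:, None] / p1[None, :]         # p_{0|1}
B1 = P[1] * p1[:, None] / p2[None, :]         # p_{1|2}

# Forward factorization (1) and backward factorization (2) agree pointwise.
for x0 in range(2):
    for x1 in range(2):
        for x2 in range(2):
            forward = p0[x0] * P[0][x0, x1] * P[1][x1, x2]
            backward = p2[x2] * B1[x1, x2] * B0[x0, x1]
            assert np.isclose(forward, backward)
print("forward and backward factorizations agree")
```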
For the purpose of generative modeling, the transition densities are chosen such that $p_N(x_N) = \int p(x_{0:N}) \, dx_{0:N-1} \approx p_{\text{prior}}(x_N)$ for large $N$, where $p_{\text{prior}}$ is an easy-to-sample prior density. To sample approximately from $p_{\text{data}}$, one may use ancestral sampling with Equation (2), i.e. first sample $X_N \sim p_{\text{prior}}$ followed by $X_k \sim p_{k|k+1}(\cdot | X_{k+1})$ for $k \in \{N-1, \dots, 0\}$.
Equation (2) cannot be simulated exactly, but may be approximated if we consider a forward transition density of the form:
\begin{equation}
p_{k+1|k}(x_{k+1} | x_k) = \mathcal{N}(x_{k+1}; x_k + \gamma_{k+1} f(x_k), 2\gamma_{k+1} \bold{I})
\end{equation}

with drift $f : \mathbb{R}^d \rightarrow \mathbb{R}^d$ and stepsize $\gamma_{k+1} > 0$.
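For concreteness, here is a minimal sketch of simulating this forward chain; the drift `f` and the stepsizes `gammas` are placeholders for whatever choices one makes in practice:

```python
import torch

def forward_chain(x0, f, gammas):
    """Simulate X_0, ..., X_N from the forward kernel in Equation (3) (a sketch).

    `f` is the drift and `gammas[k]` the stepsize gamma_k (index 0 unused);
    both are placeholders for whatever choices one makes in practice.
    """
    xs = [x0]
    x = x0
    for k in range(1, len(gammas)):
        z = torch.randn_like(x)                                   # Z_k ~ N(0, I)
        x = x + gammas[k] * f(x) + (2 * gammas[k]) ** 0.5 * z     # Equation (3)
        xs.append(x)
    return xs
```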
The backward kernel in Equation (2) can then be approximated as follows:

\begin{equation}
\begin{aligned}
p_{k|k+1}(x_k | x_{k+1})
&= p_{k+1|k}(x_{k+1} | x_k) \exp[\log p_k(x_k) - \log p_{k+1}(x_{k+1})] \\
&\approx \mathcal{N}(x_k; x_{k+1} - \gamma_{k+1} f(x_{k+1}) + 2\gamma_{k+1} \nabla \log p_{k+1}(x_{k+1}), 2\gamma_{k+1} \bold{I})
\end{aligned}
\end{equation}

using that $p_k \approx p_{k+1}$, a Taylor expansion of $\log p_{k+1}$ at $x_{k+1}$, and $f(x_k) \approx f(x_{k+1})$.
☝
How to derive the approximation?
First,
\begin{align*}
p_{k|k+1}(x_k | x_{k+1}) &= p_{k+1|k}(x_{k+1} | x_k) \frac{p_k(x_k)}{p_{k+1}(x_{k+1})} \\
&= p_{k+1|k}(x_{k+1} | x_k) \exp\left( \log \frac{p_k(x_k)}{p_{k+1}(x_{k+1})} \right) \\
&= p_{k+1|k}(x_{k+1} | x_k) \exp[\log p_k(x_k) - \log p_{k+1}(x_{k+1})]
\end{align*}

By Taylor expansion of $\log p_{k+1}(x)$ at $x = x_{k+1}$:
$$\log p_{k+1}(x) \approx \log p_{k+1}(x_{k+1}) + \nabla \log p_{k+1}(x_{k+1})^T (x - x_{k+1})$$

Since $p_k \approx p_{k+1}$, we have $\log p_k(x_k) \approx \log p_{k+1}(x_k)$. Plugging $x = x_k$ into the expansion above:

$$\log p_k(x_k) - \log p_{k+1}(x_{k+1}) \approx \nabla \log p_{k+1}(x_{k+1})^T (x_k - x_{k+1})$$

Since $p_{k+1|k}(x_{k+1} | x_k)$ is Gaussian, we have:
\begin{align*}
&p_{k+1|k}(x_{k+1} | x_k) \exp[\log p_k(x_k) - \log p_{k+1}(x_{k+1})] \\
\propto\ & \exp\left( -\frac{1}{4\gamma_{k+1}} \| x_{k+1} - x_k - \gamma_{k+1} f(x_{k+1}) \|^2 \right) \cdot \exp\left( \nabla \log p_{k+1}(x_{k+1})^T (x_k - x_{k+1}) \right) \quad \text{(using } f(x_k) \approx f(x_{k+1})\text{)} \\
\propto\ & \exp\left( -\frac{1}{4\gamma_{k+1}} \left[ \|x_k\|^2 - 2(x_{k+1} - \gamma_{k+1} f(x_{k+1}))^T x_k - 4\gamma_{k+1} \nabla \log p_{k+1}(x_{k+1})^T (x_k - x_{k+1}) \right] \right) \\
\propto\ & \exp\left( -\frac{1}{4\gamma_{k+1}} \left[ \|x_k\|^2 - 2 x_k^T \left( x_{k+1} - \gamma_{k+1} f(x_{k+1}) + 2\gamma_{k+1} \nabla \log p_{k+1}(x_{k+1}) \right) \right] \right)
\end{align*}

Since a quadratic exponent in $x_k$ characterizes a Gaussian density, we conclude that:
$$p_{k|k+1}(x_k | x_{k+1}) \approx \mathcal{N}(x_k; x_{k+1} - \gamma_{k+1} f(x_{k+1}) + 2\gamma_{k+1} \nabla \log p_{k+1}(x_{k+1}), 2\gamma_{k+1} \bold{I})$$

In practice, the approximation holds if $\|x_{k+1} - x_k\|$ is small, which is ensured by choosing $\gamma_{k+1}$ small enough.
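To see how good this approximation is in a fully explicit case, consider (as an illustrative assumption, not taken from the paper) $f = 0$ and $p_k = \mathcal{N}(0, \sigma^2 \bold{I})$, so that $p_{k+1} = \mathcal{N}(0, (\sigma^2 + 2\gamma_{k+1})\bold{I})$. Then the exact reverse kernel and the approximation above are

\begin{align*}
p_{k|k+1}(x_k | x_{k+1}) &= \mathcal{N}\left(x_k;\ \tfrac{\sigma^2}{\sigma^2 + 2\gamma_{k+1}} x_{k+1},\ \tfrac{2\gamma_{k+1}\sigma^2}{\sigma^2 + 2\gamma_{k+1}} \bold{I}\right) \quad \text{(exact, by Gaussian conjugacy)}, \\
p_{k|k+1}(x_k | x_{k+1}) &\approx \mathcal{N}\left(x_k;\ x_{k+1} + 2\gamma_{k+1} \nabla \log p_{k+1}(x_{k+1}),\ 2\gamma_{k+1}\bold{I}\right) = \mathcal{N}\left(x_k;\ \tfrac{\sigma^2}{\sigma^2 + 2\gamma_{k+1}} x_{k+1},\ 2\gamma_{k+1}\bold{I}\right).
\end{align*}

The means agree exactly, and the variances differ only by a factor $\sigma^2 / (\sigma^2 + 2\gamma_{k+1}) = 1 - O(\gamma_{k+1})$, consistent with the approximation improving as $\gamma_{k+1} \rightarrow 0$.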
$\nabla \log p_{k+1}$ is not available in closed form, but one can obtain an approximation of it using denoising score matching. We assume that the conditional density $p_{k+1|0}(x_{k+1} | x_0)$ is available analytically (e.g. it is Gaussian, so its log-gradient is explicit). One can show that $\nabla \log p_{k+1}(x_{k+1}) = \mathbb{E}_{p_{0|k+1}}\left[ \nabla_{x_{k+1}} \log p_{k+1|0}(x_{k+1} | X_0) \right]$.
☝
Proof of the equality.
First, a fact about the gradient:
\begin{align*}
\nabla \log f(x) &= \frac{\nabla f(x)}{f(x)} \\
f(x) \nabla \log f(x) &= \nabla f(x)
\end{align*}

Since $p_{k+1}(x_{k+1}) = \int p_0(x_0) p_{k+1|0}(x_{k+1} | x_0) \, dx_0$, we have:
\begin{align*}
\nabla \log p_{k+1}(x_{k+1})
&= \frac{\nabla_{x_{k+1}} p_{k+1}(x_{k+1})}{p_{k+1}(x_{k+1})} \\
&= \int \frac{p_0(x_0)}{p_{k+1}(x_{k+1})} \cdot \nabla_{x_{k+1}} p_{k+1|0}(x_{k+1} | x_0) \, dx_0 \\
&= \int \frac{p_0(x_0) p_{k+1|0}(x_{k+1} | x_0)}{p_{k+1}(x_{k+1})} \cdot \nabla_{x_{k+1}} \log p_{k+1|0}(x_{k+1} | x_0) \, dx_0 \\
&= \int p_{0|k+1}(x_0 | x_{k+1}) \cdot \nabla_{x_{k+1}} \log p_{k+1|0}(x_{k+1} | x_0) \, dx_0 \\
&= \mathbb{E}_{p_{0|k+1}}\left[ \nabla_{x_{k+1}} \log p_{k+1|0}(x_{k+1} | X_0) \right]
\end{align*}

Therefore we can formulate score estimation as a regression problem and use a flexible class of functions, e.g. neural networks, to parameterize an approximation $s_{\theta^*}(k, x_k) \approx \nabla \log p_k(x_k)$ such that:
$$\theta^* = \argmin_{\theta} \sum_{k=1}^N \mathbb{E}_{p_{0,k}}\left[ \| s_{\theta}(k, X_k) - \nabla_{x_k} \log p_{k|0}(X_k | X_0) \|^2 \right]$$

where $p_{0,k} = p_0(x_0) p_{k|0}(x_k | x_0)$. This can be done by drawing a sample $x_0$ from the dataset and using $p_{k|0}$ to obtain the corresponding $x_k$.
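A minimal sketch of this training objective, assuming a Gaussian $p_{k|0}(x_k | x_0) = \mathcal{N}(x_k; a_k x_0, b_k^2 \bold{I})$; the coefficient arrays `a`, `b` and the network `score_net` are placeholders, and the sum over $k$ is replaced by a uniform random draw over steps, as is common in practice:

```python
import torch

def dsm_loss(score_net, x0, a, b, N):
    """Denoising score-matching loss (a minimal sketch).

    Assumes p_{k|0}(x_k | x_0) = N(x_k; a[k] * x_0, b[k]^2 * I), so that
    grad_{x_k} log p_{k|0}(x_k | x_0) = -(x_k - a[k] * x_0) / b[k]^2.
    `score_net(k, x)`, `a`, `b` are placeholders (a, b are 1-D tensors of length N+1).
    """
    batch = x0.shape[0]
    k = torch.randint(1, N + 1, (batch,))            # one random step per sample
    ak, bk = a[k].view(-1, 1), b[k].view(-1, 1)
    eps = torch.randn_like(x0)
    xk = ak * x0 + bk * eps                          # x_k ~ p_{k|0}(. | x_0)
    target = -(xk - ak * x0) / bk**2                 # = -eps / b_k
    return ((score_net(k, xk) - target) ** 2).sum(dim=-1).mean()
```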
If $p_{k|0}$ is not available, one can instead use $\theta^* = \argmin_{\theta} \sum_{k=1}^N \mathbb{E}_{p_{k-1,k}}\left[ \| s_{\theta}(k, X_k) - \nabla_{x_k} \log p_{k|k-1}(X_k | X_{k-1}) \|^2 \right]$.
In summary, score-based generative modeling involves first estimating the score function $s_{\theta^*}$ from noisy data, and then sampling $X_0$ by drawing $X_N \sim p_{\text{prior}}$ and running ancestral sampling with the approximation in Equation (4), i.e.
\begin{equation}
X_k = X_{k+1} - \gamma_{k+1} f(X_{k+1}) + 2\gamma_{k+1} s_{\theta^*}(k+1, X_{k+1}) + \sqrt{2\gamma_{k+1}} Z_{k+1}
\end{equation}

where $Z_{k+1} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \bold{I})$.
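A minimal sketch of this sampler; `score_net`, the drift `f`, the stepsizes `gammas`, and the Gaussian prior with standard deviation `prior_std` are all placeholders/assumptions:

```python
import torch

@torch.no_grad()
def ancestral_sample(score_net, f, gammas, shape, prior_std=1.0):
    """Reverse-time ancestral sampling via Equation (5) (a minimal sketch).

    `score_net(k, x)` and the drift `f` are placeholders; `gammas` holds the
    stepsizes gamma_1, ..., gamma_N (index 0 unused), and p_prior is assumed
    to be N(0, prior_std^2 * I).
    """
    N = len(gammas) - 1
    x = prior_std * torch.randn(shape)                              # X_N ~ p_prior
    for k in range(N - 1, -1, -1):                                  # k = N-1, ..., 0
        g = gammas[k + 1]
        z = torch.randn(shape)                                      # Z_{k+1} ~ N(0, I)
        x = x - g * f(x) + 2 * g * score_net(k + 1, x) + (2 * g) ** 0.5 * z
    return x                                                        # approximate sample from p_data
```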
2. Continuous-Time: SDEs, Reverse-Time SDEs and Theoretical Results

The Markov chain with the kernel in Equation (3) corresponds to an Euler–Maruyama discretization of $(\bold{X}_t)_{t \in [0, T]}$, solving the following SDE:
\begin{equation}
d\bold{X}_t = f(\bold{X}_t) \, dt + \sqrt{2} \, d\bold{B}_t, \qquad \bold{X}_0 \sim p_0 = p_{\text{data}}
\end{equation}

where $(\bold{B}_t)_{t \in [0, T]}$ is a Brownian motion and $f : \mathbb{R}^d \rightarrow \mathbb{R}^d$ is regular enough so that solutions exist.
☝
Question on Equation (6)
If we strictly follow Equation (3), it may seem the SDE should be

$$d\bold{X}_t = \gamma_{t+1} f(\bold{X}_t) \, dt + \sqrt{2\gamma_{t+1}} \, d\bold{B}_t$$

whose discretization gives

$$\bold{X}_{t+1} = \bold{X}_t + \gamma_{t+1} f(\bold{X}_t) + \sqrt{2\gamma_{t+1}} \, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \bold{I}),$$

i.e. $\bold{X}_{t+1} | \bold{X}_t \sim \mathcal{N}(\bold{X}_t + \gamma_{t+1} f(\bold{X}_t), 2\gamma_{t+1} \bold{I})$. The resolution is that $\gamma_{k+1}$ plays the role of the time step of the Euler–Maruyama scheme: discretizing Equation (6) over a step of length $\gamma_{k+1}$ yields exactly the kernel of Equation (3), so no $\gamma$ appears in the continuous-time SDE itself.

Under some conditions on $f$, the reverse-time process $(\bold{Y}_t)_{t \in [0, T]} = (\bold{X}_{T-t})_{t \in [0, T]}$ satisfies
\begin{equation}
d\bold{Y}_t = \{ -f(\bold{Y}_t) + 2 \nabla \log p_{T-t}(\bold{Y}_t) \} \, dt + \sqrt{2} \, d\bold{B}_t
\end{equation}

with initialization $\bold{Y}_0 \sim p_T$, where $p_t$ denotes the marginal density of $\bold{X}_t$.
☝
Another notation in Yang Song's paper.

In that paper, the reverse-time SDE is written as:

$$d\bold{x} = [\bold{f}(\bold{x}, t) - g(t)^2 \nabla_{\bold{x}} \log p_t(\bold{x})] \, dt + g(t) \, d\bar{\bold{w}}$$

where time flows backwards from $T$ to $0$ with initialization $\bold{x}_T \sim p_T$. The two notations are equivalent; they differ only in the direction in which time flows.
The reverse-time Markov chain $\{Y_k\}_{k=0}^N$ associated with Equation (5) corresponds to an Euler–Maruyama discretization of Equation (7), where the score function is approximated by $s_{\theta^*}(t, x)$.
Let's consider $f(x) = -\alpha x$ for $\alpha \geq 0$. This framework includes the one of Song and Ermon (2019) ($\alpha = 0$, $p_{\text{prior}}(x) = \mathcal{N}(x; 0, 2T\bold{I})$), for which $(\bold{X}_t)_{t \in [0,T]}$ is simply a Brownian motion. It also includes Ho et al. (2020) ($\alpha > 0$, $p_{\text{prior}}(x) = \mathcal{N}(x; 0, \bold{I}/\alpha)$), for which it is an Ornstein–Uhlenbeck process.
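For $f(x) = -\alpha x$ the transition density of the forward SDE is available in closed form (a standard computation, stated here for reference), which also explains the two prior choices above:

\begin{align*}
\alpha > 0: \quad p_{t|0}(x_t | x_0) &= \mathcal{N}\left(x_t;\ e^{-\alpha t} x_0,\ \tfrac{1 - e^{-2\alpha t}}{\alpha} \bold{I}\right) \xrightarrow{\ t \to \infty\ } \mathcal{N}(x_t; 0, \bold{I}/\alpha), \\
\alpha = 0: \quad p_{t|0}(x_t | x_0) &= \mathcal{N}(x_t; x_0, 2t\,\bold{I}), \quad \text{so } p_T \approx \mathcal{N}(0, 2T\,\bold{I}) \text{ when } 2T \text{ is large relative to the data scale.}
\end{align*}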
3. General SGM and links with existing works (Appendix C.3 of the paper)
General SGM Algorithm

Consider the forward process
$$d\bold{X}_t = f_t(\bold{X}_t) \, dt + \sqrt{2} \, d\bold{B}_t$$

The discretization gives:

$$X_{k+1} = X_k + \gamma_{k+1} f_k(X_k) + \sqrt{2\gamma_{k+1}} Z_{k+1}$$

In general, $p_{k|0}(x_k | x_0)$ is not a Gaussian density. However, we can obtain that for any $x \in \mathbb{R}^d$,

$$p_{k+1}(x) = (4\pi\gamma_{k+1})^{-d/2} \int_{\mathbb{R}^d} p_k(\tilde{x}) \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x}$$

with $\mathcal{T}_{k+1}(\tilde{x}) = \tilde{x} + \gamma_{k+1} f_k(\tilde{x})$. After some algebra, we obtain:
$$\nabla \log p_{k+1}(x) = -(2\gamma_{k+1})^{-1/2} \, \mathbb{E}[Z_{k+1} | X_{k+1} = x]$$

☝
How is that equation derived?
Denote $\tilde{Z}_{k+1} = \sqrt{2\gamma_{k+1}} Z_{k+1} \sim \mathcal{N}(0, 2\gamma_{k+1}\bold{I})$; then,
\begin{align*}
p_{k+1}(x) &= \int p_k(\tilde{x}) \, \mathcal{N}(x - \mathcal{T}_{k+1}(\tilde{x}); 0, 2\gamma_{k+1}\bold{I}) \, d\tilde{x} \\
&= \int p_k(\tilde{x}) \, (4\pi\gamma_{k+1})^{-d/2} \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x}
\end{align*}

Taking the gradient with respect to $x$ on both sides and exchanging gradient and integration, we get:
\begin{align*}
\nabla p_{k+1}(x) &= \int p_k(\tilde{x}) \, \frac{\mathcal{T}_{k+1}(\tilde{x}) - x}{2\gamma_{k+1}} \, (4\pi\gamma_{k+1})^{-d/2} \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x}
\end{align*}

Since $\nabla p_{k+1}(x) = p_{k+1}(x) \nabla \log p_{k+1}(x)$, we derive:
\begin{align*}
[2\gamma_{k+1} \, p_{k+1}(x)] \, \nabla \log p_{k+1}(x) &= \int p_k(\tilde{x}) \, [\mathcal{T}_{k+1}(\tilde{x}) - x] \, (4\pi\gamma_{k+1})^{-d/2} \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x} \\
2\gamma_{k+1} \nabla \log p_{k+1}(x) &= \frac{\int p_k(\tilde{x}) \, [\mathcal{T}_{k+1}(\tilde{x}) - x] \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x}}{\int p_k(\tilde{x}) \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] d\tilde{x}}
\end{align*}

Now we inspect the conditional density $p_{k+1|k}$. Note that:
\begin{align*}
p_{k+1|k}(x | \tilde{x}) &= p_{\tilde{Z}_{k+1}}(x - \mathcal{T}_{k+1}(\tilde{x})) \\
&= (4\pi\gamma_{k+1})^{-d/2} \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right]
\end{align*}

Then, by Bayes' formula,
\begin{align*}
p_{k|k+1}(\tilde{x} | x) &= \frac{p_{k+1|k}(x | \tilde{x}) \, p_k(\tilde{x})}{p_{k+1}(x)} \\
&= \frac{1}{p_{k+1}(x)} \, p_k(\tilde{x}) \, (4\pi\gamma_{k+1})^{-d/2} \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right]
\end{align*}

Therefore, the score function can be written as:
\begin{align*}
\nabla \log p_{k+1}(x) &= \frac{1}{2\gamma_{k+1}} \frac{\int [\mathcal{T}_{k+1}(\tilde{x}) - x] \, p_k(\tilde{x}) \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] / p_{k+1}(x) \, d\tilde{x}}{\int p_k(\tilde{x}) \exp\left[ -\|\mathcal{T}_{k+1}(\tilde{x}) - x\|^2 / (4\gamma_{k+1}) \right] / p_{k+1}(x) \, d\tilde{x}} \\
&= \frac{1}{2\gamma_{k+1}} \int [\mathcal{T}_{k+1}(\tilde{x}) - x] \, p_{k|k+1}(\tilde{x} | x) \, d\tilde{x} \\
&= \frac{1}{2\gamma_{k+1}} \, \mathbb{E}_{X_k \sim p_{k|k+1}}\left[ \mathcal{T}_{k+1}(X_k) - X_{k+1} \mid X_{k+1} = x \right] \\
&= \frac{1}{2\gamma_{k+1}} \, \mathbb{E}_{Z_{k+1} \sim \mathcal{N}(0, \bold{I})}\left[ -\sqrt{2\gamma_{k+1}} \, Z_{k+1} \mid X_{k+1} = x \right] \\
&= -(2\gamma_{k+1})^{-\frac{1}{2}} \, \mathbb{E}[Z_{k+1} | X_{k+1} = x]
\end{align*}
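This identity suggests estimating the score in the general case by regressing on the noise: a network trained with an L2 loss to predict $Z_{k+1}$ from $(k+1, X_{k+1})$ approximates the conditional expectation $\mathbb{E}[Z_{k+1} | X_{k+1}]$, and the score then follows from the formula above. A minimal sketch, where `noise_net`, the drift `f`, and the stepsizes `gammas` are placeholders:

```python
import torch

def general_sgm_loss(noise_net, x0, f, gammas):
    """Noise-prediction loss for the general forward chain (a minimal sketch).

    Simulates X_{k+1} = X_k + gamma_{k+1} * f_k(X_k) + sqrt(2 gamma_{k+1}) * Z_{k+1}
    from a data batch x0, picks a random step k, and regresses noise_net(k, X_k)
    onto Z_k; the L2-optimal predictor is E[Z_k | X_k], so the score can be
    recovered afterwards as s(k, x) = -(2 gamma_k)^(-1/2) * noise_net(k, x).
    `noise_net`, the drift `f(k, x)` and the stepsizes `gammas` are placeholders.
    """
    N = len(gammas) - 1
    xs, zs = [x0], [None]
    x = x0
    for k in range(1, N + 1):
        z = torch.randn_like(x)
        x = x + gammas[k] * f(k - 1, x) + (2 * gammas[k]) ** 0.5 * z
        xs.append(x)
        zs.append(z)
    k = int(torch.randint(1, N + 1, (1,)))           # random step, one per call
    return ((noise_net(k, xs[k]) - zs[k]) ** 2).sum(dim=-1).mean()
```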